
package com.gitee.jmash.rbac.client.shiro;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.eclipse.microprofile.jwt.Claims;
import org.eclipse.microprofile.jwt.JsonWebToken;
import org.jose4j.jwt.JwtClaims;
import org.jose4j.jwt.MalformedClaimException;

/**
 * Serializable {@link JsonWebToken} that stands in for {@code DefaultJWTCallerPrincipal}.
 *
 * <p>Only the claims map (as returned by {@link JwtClaims#getClaimsMap()}) and an optional token
 * type are kept, so instances can be Java-serialized provided the claim values themselves are
 * serializable (presumably the String/Number/List/Map values produced by JSON parsing — confirm
 * for custom claims).
 *
 * @author CGD
 *
 */
public class JmashJsonWebToken implements JsonWebToken, Serializable {

  private static final long serialVersionUID = 1L;

  private static final Log log = LogFactory.getLog(JmashJsonWebToken.class);

  /** Claim name to claim value, taken from the source {@link JwtClaims}. */
  private final Map<String, Object> claimsMap;

  /** Token type supplied by the creator; {@code null} when none was given. */
  private final String tokenType;

  /** Creates a token with no claims and no token type. */
  public JmashJsonWebToken() {
    this(new JwtClaims());
  }

  /**
   * Creates a token from the given claims.
   *
   * @param tokenType value reported by {@link #getTokenType()}; may be {@code null}
   * @param claimsSet source of the claims; the map returned by its {@code getClaimsMap()} is stored
   */
  public JmashJsonWebToken(String tokenType, JwtClaims claimsSet) {
    this.claimsMap = claimsSet.getClaimsMap();
    this.tokenType = tokenType;
  }

  /**
   * Creates a token from the given claims without a token type.
   *
   * @param claimsSet source of the claims
   */
  public JmashJsonWebToken(JwtClaims claimsSet) {
    this(null, claimsSet);
  }

  /**
   * Returns the principal name: the {@code upn} claim, falling back to
   * {@code preferred_username}, then {@code sub}.
   *
   * @return the first non-null of those claims, or {@code null} if none is present
   */
  @Override
  public String getName() {
    String principalName = getClaim(Claims.upn.name());
    if (principalName == null) {
      principalName = getClaim(Claims.preferred_username.name());
    }
    if (principalName == null) {
      principalName = getClaim(Claims.sub.name());
    }
    return principalName;
  }

  /**
   * Returns the names of all claims in this token.
   *
   * @return an unmodifiable view of the claim names, so callers cannot remove claims through it
   */
  @Override
  public Set<String> getClaimNames() {
    return Collections.unmodifiableSet(claimsMap.keySet());
  }

  /**
   * Returns the value of a claim, converting well-known claims to their MicroProfile types.
   *
   * <ul>
   *   <li>{@code exp}, {@code iat}, {@code auth_time}, {@code nbf}, {@code updated_at}: a
   *       {@code Long} of seconds; {@code 0L} when absent, {@code null} when not numeric.</li>
   *   <li>{@code groups}, {@code aud}: a {@code Set<String>} (never {@code null}).</li>
   *   <li>any other name: the raw map value, or {@code null}.</li>
   * </ul>
   *
   * @param claimName the claim name; {@code null} is treated as an unknown claim
   * @return the claim value, unchecked-cast to the caller's expected type
   */
  @SuppressWarnings("unchecked")
  @Override
  public <T> T getClaim(String claimName) {
    Claims claimType = getClaimType(claimName);
    Object claim;

    switch (claimType) {
      // jose4j NumericDate claims are exposed as seconds since the epoch.
      case exp:
      case iat:
      case auth_time:
      case nbf:
      case updated_at:
        claim = getNumericDateClaim(claimType.name());
        break;
      case groups:
      case aud:
        claim = getStringSetClaimValue(claimName);
        break;
      case UNKNOWN:
        claim = claimsMap.get(claimName);
        break;
      default:
        claim = claimsMap.get(claimType.name());
        break;
    }
    return (T) claim;
  }

  /**
   * Reads a NumericDate claim as seconds.
   *
   * <p>Any {@link Number} is accepted (as jose4j's own NumericDate accessor does), not just
   * {@code Long}.
   *
   * @param claimName the claim name
   * @return the value as a {@code Long}; {@code 0L} when absent; {@code null} when the value is
   *         not a number (the problem is logged)
   */
  private Long getNumericDateClaim(String claimName) {
    try {
      Number value = getClaimValue(claimName, Number.class);
      return value == null ? 0L : value.longValue();
    } catch (MalformedClaimException e) {
      log.error(claimName, e);
      return null;
    }
  }

  /**
   * Returns a string-set claim.
   *
   * <p>A claim holding a single string (allowed for {@code aud} by RFC 7519) becomes a one-element
   * set. A claim holding an array of strings becomes a set of those strings.
   *
   * @param claimName the claim name
   * @return a new mutable set; empty when the claim is absent or malformed (malformed values are
   *         logged)
   */
  public Set<String> getStringSetClaimValue(String claimName) {
    Set<String> values = new HashSet<>();
    Object raw = claimsMap.get(claimName);
    if (raw instanceof String) {
      values.add((String) raw);
      return values;
    }
    try {
      List<String> list = getStringListClaimValue(claimName);
      if (list != null) {
        values.addAll(list);
      }
    } catch (MalformedClaimException e) {
      // Pass the exception as the throwable so the stack trace is not lost.
      log.error(claimName, e);
    }
    return values;
  }

  /**
   * Gets the value of the claim as a List of Strings, which assumes that it is a JSON array of
   * strings.
   *
   * @param claimName the name of the claim.
   * @return a {@code List<String>} with the values of the claim. Empty list, if the claim is not
   *         present.
   * @throws MalformedClaimException if the claim value is not an array or is an array that contains
   *         non string values
   */
  @SuppressWarnings("rawtypes")
  public List<String> getStringListClaimValue(String claimName) throws MalformedClaimException {
    List listClaimValue = getClaimValue(claimName, List.class);
    return toStringList(listClaimValue, claimName);
  }

  /**
   * Copies a claim list into a {@code List<String>}, rejecting non-string elements.
   *
   * @param list the raw claim list; {@code null} yields an empty list
   * @param claimName the claim name, used in the error message
   * @return the elements as strings ({@code null} elements are kept as {@code null})
   * @throws MalformedClaimException if any non-null element is not a {@code String}
   */
  private List<String> toStringList(List<?> list, String claimName)
      throws MalformedClaimException {
    if (list == null) {
      return Collections.emptyList();
    }
    List<String> values = new ArrayList<>(list.size());
    for (Object element : list) {
      if (element != null && !(element instanceof String)) {
        throw new MalformedClaimException(
            "The array value of the '" + claimName + "' claim contains non string values ");
      }
      values.add((String) element);
    }
    return values;
  }

  /**
   * Returns a claim value checked against the expected type.
   *
   * @param claimName the claim name
   * @param type the expected value type
   * @return the value, or {@code null} when the claim is absent
   * @throws MalformedClaimException if the value is present but not an instance of {@code type}
   */
  public <T> T getClaimValue(String claimName, Class<T> type) throws MalformedClaimException {
    Object value = claimsMap.get(claimName);
    if (value != null && !type.isInstance(value)) {
      throw new MalformedClaimException(
          "The value of the '" + claimName + "' claim is not the expected type ");
    }
    return type.cast(value);
  }

  /**
   * Returns the token type given at construction.
   *
   * @return the token type, or {@code null} when none was supplied
   */
  public String getTokenType() {
    return tokenType;
  }

  /**
   * Maps a claim name to its {@link Claims} constant.
   *
   * @param claimName the claim name; may be {@code null}
   * @return the matching constant, or {@link Claims#UNKNOWN} for {@code null} or unrecognized names
   */
  protected Claims getClaimType(String claimName) {
    if (claimName == null) {
      // Enum.valueOf throws NullPointerException (not IllegalArgumentException) for null.
      return Claims.UNKNOWN;
    }
    try {
      return Claims.valueOf(claimName);
    } catch (IllegalArgumentException e) {
      return Claims.UNKNOWN;
    }
  }

}
